
Conversation

@Chenyaaang
Collaborator

@Chenyaaang Chenyaaang commented Nov 20, 2025

Description

Fix a numerical issue in hybrid KV cache allocation. When hybrid KV cache is enabled, the block_id differs between KV cache groups in each allocation round, which means different layers write to different block_ids. We therefore need to create individual attention metadata for each layer, instead of sharing the same attention metadata across every layer.
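The idea above can be sketched as follows. This is a minimal, hypothetical illustration (the names `AttentionMetadata`, `build_attention_metadata`, and the layer/group mapping are assumptions, not the actual tpu-inference API): each layer looks up the block_ids of its own KV cache group, rather than every layer reusing one shared metadata object.

```python
# Minimal sketch (hypothetical names) of the fix described above: with hybrid
# KV cache, each KV cache group can be assigned different block_ids in an
# allocation round, so attention metadata must be built per layer (per group)
# rather than shared across all layers.
from dataclasses import dataclass


@dataclass
class AttentionMetadata:
    """Hypothetical per-layer metadata: the blocks this layer writes to."""
    block_ids: list[int]


def build_attention_metadata(
    layer_to_group: dict[str, int],
    group_block_ids: dict[int, list[int]],
) -> dict[str, AttentionMetadata]:
    """Build one AttentionMetadata per layer from its KV cache group's blocks.

    Before the fix, one shared metadata (effectively one group's block_ids)
    was used for every layer, so layers in other groups wrote to the wrong
    blocks, producing the numerical issue.
    """
    return {
        layer: AttentionMetadata(block_ids=group_block_ids[group])
        for layer, group in layer_to_group.items()
    }


# Example: full-attention layers in group 0, sliding-window layers in group 1.
metadata = build_attention_metadata(
    layer_to_group={"layers.0": 0, "layers.1": 1},
    group_block_ids={0: [4, 5], 1: [9]},
)
assert metadata["layers.0"].block_ids != metadata["layers.1"].block_ids
```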

Tests

  • Unit tests in tpu_worker and tpu_runner passed.
  • The results with and without hybrid KV cache are identical when running offline_inference.py with a Gemma model: python examples/offline_inference.py --model google/gemma-3-27b-it --tensor-parallel-size 8
  • CI: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5787 — all tasks are green except for lora, which I believe fails due to an upstream change unrelated to this PR.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@github-actions

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Collaborator

@py4 py4 left a comment


This PR doesn't have any tests. Please add the following tests:

  1. e2e correctness test: output with and without hybrid allocation is the same.
  2. e2e performance test: performance with the hybrid allocator is higher than without it.
  3. Unit tests for the changed Python files and the runner. We need to keep coverage above 70%, and our PRs need to come with enough tests.
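The requested e2e correctness check could look like the following sketch. The `generate_*` callables are hypothetical stand-ins for running the engine with and without hybrid KV cache allocation; the real test would wire them to the actual inference entry point.

```python
# Hypothetical sketch of the e2e correctness test requested above: generate
# from the same prompts under both configurations (with and without hybrid
# KV cache allocation) and require identical outputs for every prompt.
from typing import Callable


def outputs_match(
    generate_with_hybrid: Callable[[str], str],
    generate_without_hybrid: Callable[[str], str],
    prompts: list[str],
) -> bool:
    """True iff both configurations produce identical text for every prompt."""
    return all(
        generate_with_hybrid(p) == generate_without_hybrid(p) for p in prompts
    )


# Toy usage with stand-in generators; a real test would call the engine.
assert outputs_match(str.upper, str.upper, ["hello", "world"])
```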

Collaborator

@py4 py4 left a comment


Does this also work for the JAX path? If not, can we also make the JAX path work?

@Chenyaaang
Copy link
Collaborator Author

Chenyaaang commented Nov 21, 2025

Does this also work for the JAX path? If not, can we also make the JAX path work?

It should be backend agnostic, but to enable it in JAX we need to modify each individual JAX model. Previously, none of the JAX models needed hybrid KV cache, so it was not enabled there. The numerical issue was also reported with a vLLM model rather than flax nnx.

@Chenyaaang Chenyaaang closed this Nov 21, 2025
@Chenyaaang Chenyaaang reopened this Nov 21, 2025

Signed-off-by: Chenyaaang <[email protected]>
@kyuyeunk
Copy link
Collaborator

With this PR, I've verified on gpt-oss that the numerical issue has been solved, and that a performance issue stemming from it has also been resolved.

Signed-off-by: Chenyaaang <[email protected]>
@Chenyaaang Chenyaaang merged commit cfc7610 into main Nov 22, 2025
3 checks passed